using JLD2, Gadfly

include("cenf.jl")

# Initialise the cenf module's data before any counterfactual is solved
# (presumably reads the inputs used by cenf.solveCountries — see cenf.jl).
# NOTE(review): `root` is not defined in this script; presumably it is a global
# provided by cenf.jl or by the including session — confirm before running standalone.
cenf.readData(root)

# Variable theta, z^(2) **************************************************************************************************
#
# Load the multistart estimation results, select the start with the largest
# objective value, and solve the welfare counterfactual at those estimates.

println("Welfare counterfactual for Z_F (variable theta)")

f = load("../output/ms_f_output_withvartheta_refined.jld2")
ms_fvals = f["ms_fvals"]
ms_θ = f["ms_θ"]
ms_β = f["ms_β"]
ms_η = f["ms_η"]
ms_γ = f["ms_γ"]
ms_logT = f["ms_logT"]
ms_logS = f["ms_logS"]

# Index of the best multistart run. argmax returns the first maximiser (and the
# first NaN if any), exactly like findfirst(isequal(maximum(v)), v).
idx_best = argmax(vec(ms_fvals))

bestθ = ms_θ[idx_best]
bestβ = vec(ms_β[idx_best, :])
bestη = vec(ms_η[idx_best, :])
# NOTE(review): 35 and 109 are hard-coded reshape sizes (presumably sectors and
# countries) — confirm they match the dimensions cenf.solveCountries expects.
bestγ = reshape(ms_γ[idx_best, :], 35, 35)
bestlogT = reshape(ms_logT[idx_best, :], 109, 35)
bestlogS = reshape(ms_logS[idx_best, :], 109, 35)

# Base name (without extension) for the figure of this counterfactual. The SVG at
# this base name is converted to PDF right after the solve, so solveCountries
# presumably writes "$(plotfilename).svg" — confirm in cenf.jl.
plotfilename = "../output/welfare_us_variabletheta_f"
(changeArray_VF, ioshares_before_VF, ioshares_after_VF) =
    cenf.solveCountries(bestθ, bestβ, bestη, bestγ, 2.0, exp.(bestlogT), exp.(bestlogS),
                        plotfilename, "theta estimated, z^(2) used")
run(`cairosvg $(plotfilename).svg -o $(plotfilename).pdf`)

# compute the counterfactual for different values of theta: *************************************************
#
# Re-solve the model at the best estimates from above, varying only theta over a
# fixed grid, and plot the first column of each result against cenf.δ[:, 1, 1].

# NOTE(review): presumably the grid of enforcement costs (it is plotted under the
# "Enforcement cost" axis label) — confirm in cenf.jl.
δarray = cenf.δ[:, 1, 1]
θarray = [3.0, 4.0, 5.0, 6.0]

# Keep only the first return value of each solve. The two IO-share outputs are
# discarded with `_` so they do not overwrite ioshares_before_VF /
# ioshares_after_VF computed above for the estimated theta.
out = Vector{Array{Float64,2}}(undef, length(θarray))
for i in eachindex(θarray)
    out[i], _, _ = cenf.solveCountries(θarray[i], bestβ, bestη, bestγ, 2.0,
                                       exp.(bestlogT), exp.(bestlogS), "", "")
end

# One long-format DataFrame per theta, stacked for plotting; the label text is
# derived from θarray, so adding grid points needs no further edits here.
# NOTE(review): `DataFrame` is not brought into scope by this file's `using` line;
# presumably it comes from cenf.jl — otherwise add `using DataFrames`.
dfParts = [DataFrame(delta = δarray, dU = out[i][:, 1],
                     Legend = "theta = $(θarray[i])") for i in eachindex(θarray)]
dfPlot = vcat(dfParts...)

p = Gadfly.plot(dfPlot, x=:delta, y=:dU, color="Legend", Geom.point, shape="Legend",
                Guide.title("Estimates using variable theta, z^(2)"),
                Guide.xlabel("Enforcement cost"),
                Guide.ylabel("Counterfactual welfare increase, in percent"),
                Theme(default_color = colorant"red",
                      highlight_width = 0pt))

# Write the figure as SVG, then convert it to PDF with the external cairosvg tool.
bythetafilename = "../output/welfare_us_variabletheta_f-bytheta"
draw(SVG("$(bythetafilename).svg", 15cm, 10cm), p)
run(`cairosvg $(bythetafilename).svg -o $(bythetafilename).pdf`)
